
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from Utils import *


class Infer_Net_Wei(nn.Module):
    """
    Weibull inference network for (multi-layer) topic proportions.

    The bag-of-words input is embedded by a first linear layer, then passed
    through one residual hidden block per topic layer. Each hidden state is
    mapped to the shape and scale of a Weibull distribution, from which the
    topic proportions of that layer are drawn by reparameterization.
    """
    def __init__(self, v=2000, d_hidden=(300, 300, 300), topic_list=(128, 64, 32)):
        """
        :param v: vocabulary size, i.e. width of the bow input
        :param d_hidden: hidden width used by each layer
        :param topic_list: number of topics of each layer
        """
        super(Infer_Net_Wei, self).__init__()
        # Defaults are immutable tuples; copy into lists so no default object
        # is ever shared between instances.
        d_hidden = list(d_hidden)
        topic_list = list(topic_list)
        self.v = v
        self.d_hidden = d_hidden
        self.k = topic_list
        self.layer_num = len(topic_list)

        # hidden_layer[0] embeds the bow vector; hidden_layer[t + 1] is the
        # transformation used by the residual block of layer t.
        hidden_layer = [nn.Linear(v, d_hidden[0]), nn.Linear(d_hidden[0], d_hidden[0])]
        hidden_layer += [nn.Linear(d_in, d_out) for d_in, d_out in zip(d_hidden[:-1], d_hidden[1:])]
        self.hidden_layer = nn.ModuleList(hidden_layer)
        # bn_layer is indexed like hidden_layer (index 0 is not used by forward).
        self.bn_layer = nn.ModuleList([nn.BatchNorm1d(d_hidden[0])] + [nn.BatchNorm1d(hidden_size) for hidden_size in d_hidden])
        self.dropout = nn.Dropout(p=0.1)
        # Each encoder emits 2 * k values: Weibull shape and scale for k topics.
        self.encoder = nn.ModuleList([nn.Linear(h, 2 * k, bias=True) for h, k in zip(d_hidden, topic_list)])

    def res_block(self, x, layer_num):
        """
        Residual block of the hidden path.

        :param x: input hidden state, batch, d_in
        :param layer_num: index into ``hidden_layer`` / ``bn_layer``
        :return: dropout(relu(batch_norm(x + linear(x)))); the skip connection
                 is dropped when ``x`` and ``linear(x)`` cannot be added
        """
        x1 = self.hidden_layer[layer_num](x)
        try:
            out = x + x1
        except RuntimeError:
            # Shapes are incompatible (layer changes the hidden width), so no
            # skip connection is possible. Only RuntimeError is caught so that
            # unrelated failures (e.g. KeyboardInterrupt) are not swallowed.
            out = x1
        return self.dropout(F.relu(self.bn_layer[layer_num](out)))

    def reparameterize(self, wei_shape, wei_scale, sample_num=5):
        """
        :param wei_shape: batch, k
        :param wei_scale: batch, k
        :param sample_num: number of uniform draws averaged per entry
        :return: Weibull reparameterization, batch, k (mean over the draws,
                 each draw clamped to [1e-10, 100])
        """
        eps = torch.rand(sample_num, wei_shape.shape[0], wei_shape.shape[1], device=wei_shape.device)
        # theta = scale * (-log(eps)) ** (1 / shape); the leading sample axis
        # is handled by broadcasting. 1e-10 guards against log(0).
        theta = wei_scale.unsqueeze(0) * torch.pow(-torch.log(eps + 1e-10), (1 / wei_shape).unsqueeze(0))
        return torch.mean(torch.clamp(theta, 1e-10, 100.0), dim=0, keepdim=False)   ### for Nan case

    def forward(self, x):
        """
        :param x: document bow vector, batch, v
        :return: list with one tensor per layer (batch, k_t), the unnormalized
                 topic proportions of that layer
        """
        hidden = self.hidden_layer[0](x)
        theta_list = []
        for t in range(self.layer_num):
            # Layer t consumes the hidden state produced by layer t - 1
            # (or the bow embedding for t == 0).
            hidden = self.res_block(hidden, t + 1)

            # softplus keeps both Weibull parameters positive.
            k, l = torch.chunk(F.softplus(self.encoder[t](hidden)), 2, dim=1)

            # Clamp shape and scale to a numerically safe range.
            k = torch.clamp(k, 0.1, 100.0)
            l = torch.clamp(l, 1e-4, 1e4)

            theta_list.append(self.reparameterize(k, l))
        return theta_list

class Infer_Net_norm(nn.Module):
    """
    Gaussian (softmax-normalized) inference network for topic proportions.

    Shares the hidden path of the Weibull network: a bow embedding followed by
    one residual block per topic layer. Each hidden state is mapped to the
    mean and log-variance of a Gaussian; a reparameterized sample is pushed
    through a softmax to obtain the topic proportions of that layer.
    """
    def __init__(self, v=2000, d_hidden=(300, 300, 300), topic_list=(128, 64, 32)):
        """
        :param v: vocabulary size, i.e. width of the bow input
        :param d_hidden: hidden width used by each layer
        :param topic_list: number of topics of each layer
        """
        super(Infer_Net_norm, self).__init__()
        # Defaults are immutable tuples; copy into lists so no default object
        # is ever shared between instances.
        d_hidden = list(d_hidden)
        topic_list = list(topic_list)
        self.v = v
        self.d_hidden = d_hidden
        self.k = topic_list
        self.layer_num = len(topic_list)

        # hidden_layer[0] embeds the bow vector; hidden_layer[t + 1] is the
        # transformation used by the residual block of layer t.
        hidden_layer = [nn.Linear(v, d_hidden[0]), nn.Linear(d_hidden[0], d_hidden[0])]
        hidden_layer += [nn.Linear(d_in, d_out) for d_in, d_out in zip(d_hidden[:-1], d_hidden[1:])]
        self.hidden_layer = nn.ModuleList(hidden_layer)
        # bn_layer is indexed like hidden_layer (index 0 is not used by forward).
        self.bn_layer = nn.ModuleList([nn.BatchNorm1d(d_hidden[0])] + [nn.BatchNorm1d(hidden_size) for hidden_size in d_hidden])
        self.dropout = nn.Dropout(p=0.1)
        # Each encoder emits 2 * k values: mean and log-variance for k topics.
        self.encoder = nn.ModuleList([nn.Linear(h, 2 * k, bias=True) for h, k in zip(d_hidden, topic_list)])

    def res_block(self, x, layer_num):
        """
        Residual block of the hidden path.

        :param x: input hidden state, batch, d_in
        :param layer_num: index into ``hidden_layer`` / ``bn_layer``
        :return: dropout(relu(batch_norm(x + linear(x)))); the skip connection
                 is dropped when ``x`` and ``linear(x)`` cannot be added
        """
        x1 = self.hidden_layer[layer_num](x)
        try:
            out = x + x1
        except RuntimeError:
            # Shapes are incompatible (layer changes the hidden width), so no
            # skip connection is possible. Only RuntimeError is caught so that
            # unrelated failures (e.g. KeyboardInterrupt) are not swallowed.
            out = x1
        return self.dropout(F.relu(self.bn_layer[layer_num](out)))

    def reparameterize(self, mu, logvar):
        """Returns a sample from a Gaussian distribution via reparameterization.

        :param mu: mean, batch, k
        :param logvar: log-variance, batch, k
        :return: mu + eps * exp(0.5 * logvar) with eps ~ N(0, 1)
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        # In-place ops act on the freshly drawn noise only, never on the inputs.
        return eps.mul_(std).add_(mu)

    def forward(self, x):
        """
        :param x: document bow vector, batch, v
        :return: list with one tensor per layer (batch, k_t); each row is a
                 softmax-normalized topic proportion
        """
        hidden = self.hidden_layer[0](x)
        theta_list = []
        for t in range(self.layer_num):
            # Layer t consumes the hidden state produced by layer t - 1
            # (or the bow embedding for t == 0).
            hidden = self.res_block(hidden, t + 1)

            # The second chunk is named logsigma but is consumed as a
            # log-variance by reparameterize.
            mu, logsigma = torch.chunk(self.encoder[t](hidden), 2, dim=1)
            z = self.reparameterize(mu, logsigma)
            theta_list.append(F.softmax(z, dim=-1))

        return theta_list


class WeTe(nn.Module):
    """
    Topic model built on word and topic embeddings.

    Holds one word-embedding table, one topic-embedding table per topic layer
    and a Weibull inference network. The training loss combines a
    bi-directional transport cost between a document's words and its topic
    proportions with a Poisson reconstruction term.
    """
    def __init__(self, args, voc=None):
        """
        :param args: configuration object; reads K (topics per layer), H
                     (hidden widths), vocsize, embedding_dim, beta, epsilon,
                     init_alpha, device and glove
        :param voc: vocabulary as a sequence of words; required when
                    ``args.glove`` is given
        """
        super(WeTe, self).__init__()
        self.topic_k = args.K
        self.hidden = args.H
        self.layer_num = len(self.topic_k)
        self.voc_size = args.vocsize
        self.h = args.embedding_dim
        self.beta = args.beta
        self.epsilon = args.epsilon
        self.real_min = torch.tensor(1e-30)
        self.init_alpha = args.init_alpha
        self.device = args.device
        self.voc = voc
        # Filled by update_embeddings with one (k, h) tensor per layer.
        self.alpha = [0] * len(self.topic_k)

        # Index tensors of shape (n, 1) used to read the full embedding tables.
        self.topic_id = [torch.tensor([[i] for i in range(topic_k)], device=self.device) for topic_k in self.topic_k]
        self.word_id = torch.tensor([[i] for i in range(self.voc_size)], device=self.device)
        self.topic_layer = nn.ModuleList([nn.Embedding(topic_k, self.h).to(self.device) for topic_k in self.topic_k])
        self.word_layer = nn.Embedding(self.voc_size, self.h).to(self.device)
        self.InferNet = Infer_Net_Wei(v=self.voc_size, d_hidden=self.hidden, topic_list=self.topic_k)

        self.init_topic(glove=args.glove)
        self.update_embeddings()

    def init_topic(self, glove=None):
        """
        Initialize the word and topic embedding tables.

        :param glove: Path to pretrained glove embedding, or None to draw the
                      word embeddings from N(0, 0.02)
        :return: None
        :raises ValueError: if ``glove`` is given but no vocabulary was supplied
        """
        if glove is not None:
            if self.voc is None:
                raise ValueError('a vocabulary (voc) is required to load pretrained glove embeddings')
            print(f'Load pretrained glove embeddings from : {glove}')
            word_e = np.array(np.random.rand(self.voc_size, self.h) * 0.01, dtype=np.float32)
            # Map each word to its first position once, instead of scanning
            # the vocabulary list for every line of the embedding file.
            word_to_idx = {}
            for idx, word in enumerate(self.voc):
                word_to_idx.setdefault(word, idx)
            num_trained = 0
            with open(glove, encoding='UTF-8') as f:
                for line in f:
                    sp = line.split()
                    # Only lines of the form "<word> <h floats>" are usable.
                    if len(sp) == self.h + 1:
                        idx = word_to_idx.get(sp[0])
                        if idx is not None:
                            num_trained += 1
                            word_e[idx] = [float(x) for x in sp[1:]]
            print(f'num-trained in voc_size: {num_trained}|{self.voc_size}: {1.0 * num_trained / self.voc_size}')
        else:
            print('initialize word embedding from N(0, 0.02)')
            word_e = np.random.normal(0, 0.02, size=(self.voc_size, self.h))

        if self.init_alpha:
            # Layer 0 clusters the word embeddings; every higher layer clusters
            # the centers of the layer below.
            # Presumably cluster_kmeans returns a (k, h) numpy array - confirm in Utils.
            centers = word_e
            for layer_id, k in enumerate(self.topic_k):
                centers = cluster_kmeans(centers, k)
                self.topic_layer[layer_id] = nn.Embedding.from_pretrained(torch.from_numpy(centers).float(), freeze=False).to(self.device)
        else:
            for layer_id, k in enumerate(self.topic_k):
                # NOTE(review): the std is 0.1 * layer_id, so layer 0 starts as
                # all zeros - confirm this is intended.
                topic_e = np.random.normal(0, 0.1 * layer_id, size=(k, self.h))
                self.topic_layer[layer_id] = nn.Embedding.from_pretrained(torch.from_numpy(topic_e).float(), freeze=False).to(self.device)
        self.word_layer = nn.Embedding.from_pretrained(torch.from_numpy(word_e).float(), freeze=False).to(self.device)

    def save_embeddings(self, path='out.pkl'):
        """
        Pickle ``[word_embeddings, [topic_embeddings per layer]]`` as numpy
        arrays, using the values cached by the last update_embeddings call.

        :param path: output file
        """
        word_e = self.rho.cpu().detach().numpy()
        topic_e = [alpha.cpu().detach().numpy() for alpha in self.alpha]
        with open(path, 'wb') as f:
            pickle.dump([word_e, topic_e], f)

    def update_embeddings(self):
        """
        Refresh the cached embedding matrices: ``self.rho`` (v, h) and
        ``self.alpha[layer]`` (k, h).
        """
        # squeeze(1) removes only the index axis, so tables with a single row
        # or a single embedding dimension keep their 2-D shape.
        self.rho = self.word_layer(self.word_id).squeeze(1)
        for layer_id in range(len(self.topic_k)):
            self.alpha[layer_id] = self.topic_layer[layer_id](self.topic_id[layer_id]).squeeze(1)

    def cal_phi(self):
        """
        :return: list of column-normalized (softmax over dim 0) matrices:
                 words x topics of layer 0, then topics of layer t x topics of
                 layer t + 1
        """
        phi = [F.softmax(torch.matmul(self.rho, self.alpha[0].t()), dim=0)]
        for layer_id in range(len(self.topic_k)-1):
            phi.append(F.softmax(torch.matmul(self.alpha[layer_id], self.alpha[layer_id+1].t()), dim=0))
        return phi

    def cal_phi_layer(self, layer_id):
        """
        :param layer_id: topic layer index
        :return: v, k matrix relating words directly to the topics of that
                 layer, softmax-normalized over the vocabulary
        """
        phi_layer = F.softmax(torch.matmul(self.rho, self.alpha[layer_id].t()), dim=0)
        return phi_layer

    def cost_ct(self, inner_p, cost_c, x, theta):
        """
        :param inner_p: v, k
        :param cost_c: v, k
        :param x: batch of sequential words
        :param theta: batch, k, topic proportions
        :return: bi-direction cost
        """
        dis_d = torch.clamp(torch.exp(inner_p), 1e-30, 1e10)
        forward_cost = 0.
        backward_cost = 0.
        theta_norm = F.softmax(theta, dim=-1)
        for each, each_theta in zip(x, theta_norm):
            forward_doc_dis = dis_d[each] * each_theta[None, :]   ## N_j * K
            doc_dis = dis_d[each]  ## N_J * K
            # Forward plan: normalized over topics for every word.
            forward_pi = forward_doc_dis / (torch.sum(forward_doc_dis, dim=1, keepdim=True) + self.real_min)  ### N_j, K
            # Backward plan: normalized over the document's words for every topic.
            backward_pi = doc_dis / (torch.sum(doc_dis, dim=0, keepdim=True) + self.real_min)  ### N_j, K
            forward_cost += (cost_c[each] * forward_pi).sum(1).mean()
            backward_cost += ((cost_c[each] * backward_pi).sum(0) * each_theta).sum()
        return forward_cost, backward_cost

    def GCT(self, theta1, theta2, alpha1, alpha2):
        """
        :param theta1: batch, k1
        :param theta2: batch, k2
        :param alpha1: k1, d
        :param alpha2: k2, d
        :return: ct loss, from theta1 to theta2, and from theta2 to theta1
        """
        inner_p = torch.matmul(alpha1, alpha2.t())       #### k1, k2
        similarity = torch.clamp(torch.exp(inner_p), 1e-30, 1e10)        #### k1, k2
        cost_c = torch.clamp(torch.exp(-inner_p), 1e-30, 1e10)      #### k1, k2
        forward_cost = 0.
        backward_cost = 0.
        ####  remove some topics for speed

        for each1, each2 in zip(theta1, theta2):
            # Keep only the topics with a non-zero proportion.
            activate1 = torch.where(each1)[0]
            activate2 = torch.where(each2)[0]
            each1 = each1[activate1]
            each2 = each2[activate2]
            each1 = each1 / each1.sum()
            each2 = each2 / each2.sum()
            similarity_ = similarity[activate1][:, activate2]
            cost_c_ = cost_c[activate1][:, activate2]
            forward_s = similarity_ * each2[None]
            forward_pi = forward_s / (torch.sum(forward_s, dim=1, keepdim=True) + 1e-20)
            backward_s = similarity_ * each1[:, None]
            backward_pi = backward_s / (torch.sum(backward_s, dim=0, keepdim=True) + 1e-20)
            forward_cost += ((cost_c_ * forward_pi).sum(1) * each1).sum()
            backward_cost += ((cost_c_ * backward_pi).sum(0) * each2).sum()
        return forward_cost, backward_cost

    def Poisson_likelihood(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: Negative log of poisson likelihood, averaged over the batch
        """
        return -(x * torch.log(re_x + 1e-10) - re_x - torch.lgamma(x + 1.0)).sum(-1).mean()

    def Entropy(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: 100 times the cross entropy between the row-normalized x and
                 the row-normalized re_x, averaged over the batch
        """
        x_norm = x / torch.sum(x, dim=-1, keepdim=True)
        re_x_norm = re_x / torch.sum(re_x, dim=-1, keepdim=True)
        return -100*(x_norm * torch.log(re_x_norm + 1e-10)).sum(-1).mean()

    def Entropy_v1(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: -200 * sum(x * log(re_x)) without any normalization, averaged
                 over the batch
        """
        return -200*(x * torch.log(re_x + 1e-10)).sum(-1).mean()

    def Likelihood(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: -sum(x * log(re_x)), averaged over the batch
        """
        return -(x * torch.log(re_x + 1e-6)).sum(1).mean()

    def forward(self, x, bow):
        """
        :param x: batch of sequential words (one index sequence per document)
        :param bow: batch of bow vector, batch, v
        :return: loss, forward cost, backward cost, reconstruction cost and
                 the list of per-layer topic proportions from the inference
                 network
        """
        theta = self.InferNet(bow)
        self.update_embeddings()
        phi = self.cal_phi()
        # The inference network, alpha and phi are all per-layer lists. The
        # word-level costs are computed against the bottom layer (index 0).
        # NOTE(review): higher layers do not enter the loss here (GCT is not
        # called) - confirm whether inter-layer costs should be added.
        theta_0 = theta[0]
        ## calculate distance between word and topic embeddings
        inner_p = torch.matmul(self.rho, self.alpha[0].t())
        cost_c = torch.clamp(torch.exp(-inner_p), 1e-30, 1e10)
        forward_cost, backward_cost = self.cost_ct(inner_p, cost_c, x, theta_0)
        re_x = torch.matmul(phi[0], theta_0.t())
        TM_cost = self.Poisson_likelihood(bow, re_x.t())
        loss = self.beta * forward_cost + (1-self.beta) * backward_cost + self.epsilon * TM_cost
        return loss, forward_cost, backward_cost, TM_cost, theta



class WeTe_image(nn.Module):
    """
    Variant of the embedding-based topic model whose word table is loaded from
    externally supplied concept embeddings and kept frozen.

    Holds the frozen concept-embedding table, one trainable topic-embedding
    table per topic layer and a Weibull inference network. The training loss
    combines a bi-directional transport cost with a Poisson reconstruction
    term.
    """
    def __init__(self, args, concept_embeddings, voc=None):
        """
        :param args: configuration object; reads K (topics per layer), H
                     (hidden widths), vocsize, embedding_dim, beta, epsilon,
                     init_alpha and device
        :param concept_embeddings: array of concept embeddings used as the
                                   word table (assumed vocsize x embedding_dim
                                   - confirm against the caller)
        :param voc: vocabulary; stored but not used by this class
        """
        super(WeTe_image, self).__init__()
        self.topic_k = args.K
        self.hidden = args.H
        self.layer_num = len(self.topic_k)
        self.voc_size = args.vocsize
        self.h = args.embedding_dim
        self.beta = args.beta
        self.epsilon = args.epsilon
        self.real_min = torch.tensor(1e-30)
        self.init_alpha = args.init_alpha
        self.device = args.device
        self.voc = voc
        # Filled by update_embeddings with one (k, h) tensor per layer.
        self.alpha = [0] * len(self.topic_k)

        # Index tensors of shape (n, 1) used to read the full embedding tables.
        self.topic_id = [torch.tensor([[i] for i in range(topic_k)], device=self.device) for topic_k in self.topic_k]
        self.word_id = torch.tensor([[i] for i in range(self.voc_size)], device=self.device)
        self.topic_layer = nn.ModuleList([nn.Embedding(topic_k, self.h).to(self.device) for topic_k in self.topic_k])
        self.word_layer = nn.Embedding(self.voc_size, self.h).to(self.device)
        self.InferNet = Infer_Net_Wei(v=self.voc_size, d_hidden=self.hidden, topic_list=self.topic_k)

        self.init_topic(glove=concept_embeddings)
        self.update_embeddings()

    def init_topic(self, glove=None):
        """
        Initialize the (frozen) word table and the topic embedding tables.

        :param glove: concept embeddings as an array (not a file path)
        :return: None
        :raises ValueError: if ``glove`` is None; there is no fallback
                            initialization for the word table
        """
        if glove is None:
            # Without this check the code below would fail on an unbound name.
            raise ValueError('concept embeddings are required to initialize WeTe_image')
        print('load concept embeddings')
        word_e = np.asarray(glove)

        if self.init_alpha:
            # Layer 0 clusters the concept embeddings; every higher layer
            # clusters the centers of the layer below.
            # Presumably cluster_kmeans returns a (k, h) numpy array - confirm in Utils.
            centers = word_e
            for layer_id, k in enumerate(self.topic_k):
                centers = cluster_kmeans(centers, k)
                self.topic_layer[layer_id] = nn.Embedding.from_pretrained(torch.from_numpy(centers).float(), freeze=False).to(self.device)
        else:
            for layer_id, k in enumerate(self.topic_k):
                # NOTE(review): the std is 0.1 * layer_id, so layer 0 starts as
                # all zeros - confirm this is intended.
                topic_e = np.random.normal(0, 0.1 * layer_id, size=(k, self.h))
                self.topic_layer[layer_id] = nn.Embedding.from_pretrained(torch.from_numpy(topic_e).float(), freeze=False).to(self.device)
        # freeze=True: the concept embeddings are not trained.
        self.word_layer = nn.Embedding.from_pretrained(torch.from_numpy(word_e).float(), freeze=True).to(self.device)

    def save_embeddings(self, path='out.pkl'):
        """
        Pickle ``[word_embeddings, [topic_embeddings per layer]]`` as numpy
        arrays, using the values cached by the last update_embeddings call.

        :param path: output file
        """
        word_e = self.rho.cpu().detach().numpy()
        topic_e = [alpha.cpu().detach().numpy() for alpha in self.alpha]
        with open(path, 'wb') as f:
            pickle.dump([word_e, topic_e], f)

    def update_embeddings(self):
        """
        Refresh the cached embedding matrices: ``self.rho`` (v, h) and
        ``self.alpha[layer]`` (k, h).
        """
        # squeeze(1) removes only the index axis, so tables with a single row
        # or a single embedding dimension keep their 2-D shape.
        self.rho = self.word_layer(self.word_id).squeeze(1)
        for layer_id in range(len(self.topic_k)):
            self.alpha[layer_id] = self.topic_layer[layer_id](self.topic_id[layer_id]).squeeze(1)

    def cal_phi(self):
        """
        :return: list of column-normalized (softmax over dim 0) matrices:
                 words x topics of layer 0, then topics of layer t x topics of
                 layer t + 1
        """
        phi = [F.softmax(torch.matmul(self.rho, self.alpha[0].t()), dim=0)]
        for layer_id in range(len(self.topic_k)-1):
            phi.append(F.softmax(torch.matmul(self.alpha[layer_id], self.alpha[layer_id+1].t()), dim=0))
        return phi

    def cal_phi_layer(self, layer_id):
        """
        :param layer_id: topic layer index
        :return: v, k matrix relating words directly to the topics of that
                 layer, softmax-normalized over the vocabulary
        """
        phi_layer = F.softmax(torch.matmul(self.rho, self.alpha[layer_id].t()), dim=0)
        return phi_layer

    def cost_ct(self, inner_p, cost_c, x, theta):
        """
        :param inner_p: v, k
        :param cost_c: v, k
        :param x: batch of sequential words
        :param theta: batch, k, topic proportions
        :return: bi-direction cost
        """
        dis_d = torch.clamp(torch.exp(inner_p), 1e-30, 1e10)
        forward_cost = 0.
        backward_cost = 0.
        theta_norm = F.softmax(theta, dim=-1)
        for each, each_theta in zip(x, theta_norm):
            forward_doc_dis = dis_d[each] * each_theta[None, :]   ## N_j * K
            doc_dis = dis_d[each]  ## N_J * K
            # Forward plan: normalized over topics for every word.
            forward_pi = forward_doc_dis / (torch.sum(forward_doc_dis, dim=1, keepdim=True) + self.real_min)  ### N_j, K
            # Backward plan: normalized over the document's words for every topic.
            backward_pi = doc_dis / (torch.sum(doc_dis, dim=0, keepdim=True) + self.real_min)  ### N_j, K
            forward_cost += (cost_c[each] * forward_pi).sum(1).mean()
            backward_cost += ((cost_c[each] * backward_pi).sum(0) * each_theta).sum()
        return forward_cost, backward_cost

    def GCT(self, theta1, theta2, alpha1, alpha2):
        """
        :param theta1: batch, k1
        :param theta2: batch, k2
        :param alpha1: k1, d
        :param alpha2: k2, d
        :return: ct loss, from theta1 to theta2, and from theta2 to theta1
        """
        inner_p = torch.matmul(alpha1, alpha2.t())       #### k1, k2
        similarity = torch.clamp(torch.exp(inner_p), 1e-30, 1e10)        #### k1, k2
        cost_c = torch.clamp(torch.exp(-inner_p), 1e-30, 1e10)      #### k1, k2
        forward_cost = 0.
        backward_cost = 0.
        ####  remove some topics for speed

        for each1, each2 in zip(theta1, theta2):
            # Keep only the topics with a non-zero proportion.
            activate1 = torch.where(each1)[0]
            activate2 = torch.where(each2)[0]
            each1 = each1[activate1]
            each2 = each2[activate2]
            each1 = each1 / each1.sum()
            each2 = each2 / each2.sum()
            similarity_ = similarity[activate1][:, activate2]
            cost_c_ = cost_c[activate1][:, activate2]
            forward_s = similarity_ * each2[None]
            forward_pi = forward_s / (torch.sum(forward_s, dim=1, keepdim=True) + 1e-20)
            backward_s = similarity_ * each1[:, None]
            backward_pi = backward_s / (torch.sum(backward_s, dim=0, keepdim=True) + 1e-20)
            forward_cost += ((cost_c_ * forward_pi).sum(1) * each1).sum()
            backward_cost += ((cost_c_ * backward_pi).sum(0) * each2).sum()
        return forward_cost, backward_cost

    def Poisson_likelihood(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: Negative log of poisson likelihood, averaged over the batch
        """
        return -(x * torch.log(re_x + 1e-10) - re_x - torch.lgamma(x + 1.0)).sum(-1).mean()

    def Entropy(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: 100 times the cross entropy between the row-normalized x and
                 the row-normalized re_x, averaged over the batch
        """
        x_norm = x / torch.sum(x, dim=-1, keepdim=True)
        re_x_norm = re_x / torch.sum(re_x, dim=-1, keepdim=True)
        return -100*(x_norm * torch.log(re_x_norm + 1e-10)).sum(-1).mean()

    def Entropy_v1(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: -200 * sum(x * log(re_x)) without any normalization, averaged
                 over the batch
        """
        return -200*(x * torch.log(re_x + 1e-10)).sum(-1).mean()

    def Likelihood(self, x, re_x):
        """
        :param x: batch of bow vector
        :param re_x: reconstruction, Phi times theta
        :return: -sum(x * log(re_x)), averaged over the batch
        """
        return -(x * torch.log(re_x + 1e-6)).sum(1).mean()

    def forward(self, x, bow):
        """
        :param x: batch of sequential words (one index sequence per document)
        :param bow: batch of bow vector, batch, v
        :return: loss, forward cost, backward cost, reconstruction cost and
                 the list of per-layer topic proportions from the inference
                 network
        """
        theta = self.InferNet(bow)
        self.update_embeddings()
        phi = self.cal_phi()
        # The inference network, alpha and phi are all per-layer lists. The
        # word-level costs are computed against the bottom layer (index 0).
        # NOTE(review): higher layers do not enter the loss here (GCT is not
        # called) - confirm whether inter-layer costs should be added.
        theta_0 = theta[0]
        ## calculate distance between word and topic embeddings
        inner_p = torch.matmul(self.rho, self.alpha[0].t())
        cost_c = torch.clamp(torch.exp(-inner_p), 1e-30, 1e10)
        forward_cost, backward_cost = self.cost_ct(inner_p, cost_c, x, theta_0)
        re_x = torch.matmul(phi[0], theta_0.t())
        TM_cost = self.Poisson_likelihood(bow, re_x.t())
        loss = self.beta * forward_cost + (1-self.beta) * backward_cost + self.epsilon * TM_cost
        return loss, forward_cost, backward_cost, TM_cost, theta

